fix: AccumulateGrad stream mismatch warning when using DDP with Fabric & Trainer#21746
Conversation
⚡ Required checks status: All passing 🟢Groups summary🟢 pytorch_lightning: Tests workflow
These checks are required after the changes to 🟢 fabric: Docs
These checks are required after the changes to 🟢 pytorch_lightning: Docs
These checks are required after the changes to 🟢 lightning_fabric: CPU workflow
These checks are required after the changes to 🟢 mypy
These checks are required after the changes to 🟢 install
These checks are required after the changes to Thank you for your contribution! 💜
|
|
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #21746 +/- ##
=======================================
- Coverage 87% 87% -0%
=======================================
Files 270 270
Lines 23975 23987 +12
=======================================
+ Hits 20750 20755 +5
- Misses 3225 3232 +7 |
What does this PR do?
Fixes #21567
Problem
When using
Fabricwith DDP and gradient accumulation, the following warningwas emitted on every backward pass:
This warning did not appear when using plain PyTorch DDP directly.
Root Cause
The original
setup_moduleinitializedDistributedDataParallelinside anew side-stream context (
torch.cuda.Stream()):This was intentional for supporting CUDA graph whole-network capture, which
requires DDP to be initialized on a side-stream (see PyTorch docs:
https://docs.pytorch.org/docs/2.12/notes/cuda.html#id5).
However, for normal training (the vast majority of use cases), this causes a
stream mismatch: DDP registers its
AccumulateGradhooks on the side-streamduring initialization, but all subsequent forward/backward passes run on the
default stream. PyTorch detects this cross-stream node reference and emits
the warning.
Plain PyTorch DDP does not hit this because users initialize DDP directly
without any stream context, it defaults to the default stream, so there is
no mismatch.
Fix
Detect whether we are currently inside a CUDA graph capture context using
torch.cuda.is_current_stream_capturing(), and pick the appropriate stream:capturing=False): initialize DDP on the defaultstream. No mismatch, no warning.
capturing=True): initialize DDP on a newside-stream as before (required by PyTorch). Additionally suppress the
AccumulateGradwarning globally since the mismatch is intentional inthis context.
This fix is applied to both
DDPStrategyinlightning_fabricandDDPStrategyinpytorch_lightning.Testing
Verified with the reproduction script from #21567 across 4 GPUs (torch 2.11
and 2.12). Warning no longer appears under normal DDP + gradient accumulation
training.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist